#!/bin/bash
# Full-parameter SFT launcher for a vision-language model via LLaMA-Factory.

export MASTER_PORT=23456        # rendezvous port for torch.distributed
export WANDB_PROJECT=dynamiccot # Weights & Biases project that receives this run

MODEL_PATH=$1  # base model: local path or Hugging Face Hub ID
OUTPUT_DIR=$2  # where checkpoints, logs, and the loss plot are written
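
# Both arguments are required; anything after them is forwarded verbatim to
# llamafactory-cli via "${@:3}" at the end of the script.
if [ -z "${MODEL_PATH}" ] || [ -z "${OUTPUT_DIR}" ]; then
    echo "Usage: $0 MODEL_PATH OUTPUT_DIR [extra llamafactory-cli args...]" >&2
    exit 1
fi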

# Lowercase the model path so the InternVL check below is case-insensitive.
lower_model_path=$(echo "$MODEL_PATH" | tr '[:upper:]' '[:lower:]')

if [ ! -f "./tmp.txt" ]; then
    pip3 install -e ".[torch,metrics,deepspeed]"
    if [[ "$lower_model_path" == *internvl* ]]; then
        pip3 install transformers==4.52.1
    else
        pip3 install transformers==4.49.0
    fi
    pip3 install vllm==0.7.3 --user
    touch ./tmp.txt
fi

# Full-parameter SFT: vision tower frozen, projector and language model trained.
# Extra args forwarded from the command line are appended last, so they can
# override any of the defaults below.
llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path "${MODEL_PATH}" \
    --freeze_vision_tower True \
    --freeze_multi_modal_projector False \
    --preprocessing_num_workers 16 \
    --finetuning_type full \
    --template qwen2_vl \
    --flash_attn auto \
    --dataset_dir data \
    --dataset kp_ocr \
    --cutoff_len 2048 \
    --learning_rate 5e-05 \
    --num_train_epochs 5.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 1000 \
    --warmup_ratio 0.1 \
    --optim adamw_torch \
    --packing False \
    --report_to wandb \
    --run_name qwen2.5-vl-3b-full-sft \
    --output_dir "${OUTPUT_DIR}" \
    --overwrite_output_dir True \
    --bf16 True \
    --plot_loss True \
    --ddp_timeout 180000000 \
    --include_num_input_tokens_seen True \
    --val_size 0.1 \
    --eval_steps 50 \
    "${@:3}"
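
# Example invocation (script name, paths, and the extra flag are illustrative):
#   bash train_sft.sh Qwen/Qwen2.5-VL-3B-Instruct saves/qwen2.5-vl-3b-sft \
#       --num_train_epochs 3.0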